{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Interacting with Jukebox",
      "provenance": [],
      "collapsed_sections": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "sAdFGF-bqVMY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install git+https://github.com/openai/jukebox.git"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uq8uLwZCn0BV",
        "colab_type": "text"
      },
      "source": [
        "IMPORTANT NOTE ON SYSTEM REQUIREMENTS:\n",
        "\n",
        "If you are connecting to a hosted runtime, make sure it has a P100 GPU (optionally run !nvidia-smi to confirm). Go to Edit>Notebook Settings to set this.\n",
        "\n",
        "CoLab may first assign you a lower memory machine if you are using a hosted runtime.  If so, the first time you try to load the 5B model, it will run out of memory, and then you'll be prompted to restart with more memory (then return to the top of this CoLab).  If you continue to have memory issues after this (or run into issues on your own home setup), switch to the 1B model.\n",
        "\n",
        "If you are using a local GPU, we recommend V100 or P100 with 16GB GPU memory for best performance. For GPU’s with less memory, we recommend using the 1B model and a smaller batch size throughout.  \n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8qEqdj8u0gdN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!nvidia-smi"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "taDHgk1WCC_C",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import jukebox\n",
        "import torch as t\n",
        "import librosa\n",
        "import os\n",
        "from IPython.display import Audio\n",
        "from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model\n",
        "from jukebox.hparams import Hyperparams, setup_hparams\n",
        "from jukebox.sample import sample_single_window, _sample, \\\n",
        "                           sample_partial_window, upsample\n",
        "from jukebox.utils.dist_utils import setup_dist_from_mpi\n",
        "from jukebox.utils.torch_utils import empty_cache\n",
        "rank, local_rank, device = setup_dist_from_mpi()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "89FftI5kc-Az",
        "colab_type": "text"
      },
      "source": [
        "# Sample from the 5B or 1B Lyrics Model\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "65aR2OZxmfzq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = \"5b_lyrics\" # or \"1b_lyrics\"     \n",
        "hps = Hyperparams()\n",
        "hps.sr = 44100\n",
        "hps.n_samples = 3 if model=='5b_lyrics' else 8\n",
        "hps.name = 'samples'\n",
        "chunk_size = 16 if model==\"5b_lyrics\" else 32\n",
        "max_batch_size = 3 if model==\"5b_lyrics\" else 16\n",
        "hps.levels = 3\n",
        "hps.hop_fraction = [.5,.5,.125]\n",
        "\n",
        "vqvae, *priors = MODELS[model]\n",
        "vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)\n",
        "top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JYKiwkzy0Iyf",
        "colab_type": "text"
      },
      "source": [
        "Specify your choice of artist, genre, lyrics, and length of musical sample. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-sY9aGHcZP-u",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sample_length_in_seconds = 60          # Full length of musical sample to generate - we find songs in the 1 to 4 minute\n",
        "                                       # range work well, with generation time proportional to sample length.  \n",
        "                                       # This total length affects how quickly the model \n",
        "                                       # progresses through lyrics (model also generates differently\n",
        "                                       # depending on if it thinks it's in the beginning, middle, or end of sample)\n",
        "\n",
        "hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens\n",
        "assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, f'Please choose a larger sampling rate'"
      ],
      "execution_count": 0,
      "outputs": []
    },
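    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The rounding in the cell above trims the requested length down to a whole number of top-level tokens. As a rough sketch of that arithmetic for the default 60 seconds, the cell below assumes 128 raw audio samples per top-level token; that figure is an assumption for illustration (the notebook reads the real value from `top_prior.raw_to_tokens`), and `samples_per_token`, `requested_samples`, and `rounded_samples` are made-up names."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Assumed for illustration; the notebook reads this from top_prior.raw_to_tokens\n",
        "samples_per_token = 128\n",
        "requested_samples = int(60 * 44100)  # 60 seconds at 44.1 kHz = 2646000 samples\n",
        "# Drop the remainder so the length is a whole number of tokens\n",
        "rounded_samples = (requested_samples // samples_per_token) * samples_per_token\n",
        "print(rounded_samples, requested_samples - rounded_samples)\n",
        "# 2645888 112"
      ],
      "execution_count": 0,
      "outputs": []
    },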
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qD0qxQeLaTR0",
        "colab": {}
      },
      "source": [
        "metas = [dict(artist = \"Zac Brown Band\",\n",
        "            genre = \"Country\",\n",
        "            total_length = hps.sample_length,\n",
        "            offset = 0,\n",
        "            lyrics = \"\"\"I met a traveller from an antique land,\n",
        "            Who said—“Two vast and trunkless legs of stone\n",
        "            Stand in the desert. . . . Near them, on the sand,\n",
        "            Half sunk a shattered visage lies, whose frown,\n",
        "            And wrinkled lip, and sneer of cold command,\n",
        "            Tell that its sculptor well those passions read\n",
        "            Which yet survive, stamped on these lifeless things,\n",
        "            The hand that mocked them, and the heart that fed;\n",
        "            And on the pedestal, these words appear:\n",
        "            My name is Ozymandias, King of Kings;\n",
        "            Look on my Works, ye Mighty, and despair!\n",
        "            Nothing beside remains. Round the decay\n",
        "            Of that colossal Wreck, boundless and bare\n",
        "            The lone and level sands stretch far away\n",
        "            \"\"\",\n",
        "            ),\n",
        "          ] * hps.n_samples\n",
        "labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6PHC1XnEfV4Y",
        "colab_type": "text"
      },
      "source": [
        "Optionally adjust the sampling temperature (we've found .98 or .99 to be our favorite).  \n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "eNwKyqYraTR9",
        "colab": {}
      },
      "source": [
        "sampling_temperature = .98\n",
        "\n",
        "lower_batch_size = 16\n",
        "max_batch_size = 3 if model == \"5b_lyrics\" else 16\n",
        "lower_level_chunk_size = 32\n",
        "chunk_size = 16 if model == \"5b_lyrics\" else 32\n",
        "sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,\n",
        "                        chunk_size=lower_level_chunk_size),\n",
        "                    dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,\n",
        "                         chunk_size=lower_level_chunk_size),\n",
        "                    dict(temp=sampling_temperature, fp16=True, \n",
        "                         max_batch_size=max_batch_size, chunk_size=chunk_size)]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S3j0gT3HfrRD",
        "colab_type": "text"
      },
      "source": [
        "Now we're ready to sample from the model. We'll generate the top level (2) first, followed by the first upsampling (level 1), and the second upsampling (0).  In this CoLab we load the top prior separately from the upsamplers, because of memory concerns on the hosted runtimes. If you are using a local machine, you can also load all models directly with make_models, and then use sample.py's ancestral_sampling to put this all in one step.\n",
        "\n",
        "After each level, we decode to raw audio and save the audio files.   \n",
        "\n",
        "This next cell will take a while (approximately 10 minutes per 20 seconds of music sample)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2nET_YBEopyp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(len(priors))]\n",
        "zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-gxY9aqHqfLJ",
        "colab_type": "text"
      },
      "source": [
        "Listen to the results from the top level (note this will sound very noisy until we do the upsampling stage).  You may have more generated samples, depending on the batch size you requested."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TPZENDGZqOOb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Audio(f'{hps.name}/level_2/item_0.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EJc3bQxmusc6",
        "colab_type": "text"
      },
      "source": [
        "We are now done with the large top_prior model, and instead load the upsamplers."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W5VLX0zRapIm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Set this False if you are on a local machine that has enough memory (this allows you to do the\n",
        "# lyrics alignment visualization during the upsampling stage). For a hosted runtime, \n",
        "# we'll need to go ahead and delete the top_prior if you are using the 5b_lyrics model.\n",
        "if True:\n",
        "  del top_prior\n",
        "  empty_cache()\n",
        "  top_prior=None\n",
        "upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]\n",
        "labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eH_jUhGDprAt",
        "colab_type": "text"
      },
      "source": [
        "Please note: this next upsampling step will take several hours.  At the free tier, Google CoLab lets you run for 12 hours.  As the upsampling is completed, samples will appear in the Files tab (you can access this at the left of the CoLab), under \"samples\" (or whatever hps.name is currently).  Level 1 is the partially upsampled version, and then Level 0 is fully completed."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9lkJgLolpZ6w",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3SJgBYJPri55",
        "colab_type": "text"
      },
      "source": [
        "Listen to your final sample!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2ip2PPE0rgAb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Audio(f'{hps.name}/level_0/item_0.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8JAgFxytwrLG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "del upsamplers\n",
        "empty_cache()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LpvvFH85bbBC",
        "colab_type": "text"
      },
      "source": [
        "# Co-Composing with the 5B or 1B Lyrics Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nFDROuS7gFQY",
        "colab_type": "text"
      },
      "source": [
        "For more control over the generations, try co-composing with either the 5B or 1B Lyrics Models.  Again, specify your artist, genre, and lyrics. However, now instead of generating the entire sample, the model will return 3 short options for the opening of the piece (or up to 16 options if you use the 1B model instead).  Choose your favorite, and then continue the loop, for as long as you like.  Throughout these steps, you'll be listening to the audio at the top prior level, which means it will sound quite noisy.  When you are satisfied with your co-creation, continue on through the upsampling section. This will render the piece in higher audio quality.\n",
        "\n",
        "NOTE: CoLab will first assign you a lower memory machine if you are using a hosted runtime.  The next cell will run out of memory, and then you'll be prompted to restart with more memory (then return to the top of this CoLab).  If you continue to have memory issues after this (or run into issues on your own home setup), switch to the 1B model. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3y-q8ifhGBlU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = \"5b_lyrics\" # or \"1b_lyrics\"\n",
        "hps = Hyperparams()\n",
        "hps.sr = 44100\n",
        "hps.n_samples = 3 if model=='5b_lyrics' else 16\n",
        "hps.name = 'co_composer'\n",
        "hps.sample_length = 1048576 if model==\"5b_lyrics\" else 786432 \n",
        "chunk_size = 16 if model==\"5b_lyrics\" else 32\n",
        "max_batch_size = 3 if model==\"5b_lyrics\" else 16\n",
        "hps.hop_fraction = [.5, .5, .125] \n",
        "hps.levels = 3\n",
        "\n",
        "vqvae, *priors = MODELS[model]\n",
        "vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = hps.sample_length)), device)\n",
        "top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "68hz4x7igq0c",
        "colab_type": "text"
      },
      "source": [
        "Choose your artist, genre, and lyrics here!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QDMvH_1zUHo6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "total_sample_length_in_seconds = 120\n",
        "metas = [dict(artist = \"Zac Brown Band\",\n",
        "            genre = \"Country\",\n",
        "            total_length = total_sample_length_in_seconds * hps.sr,\n",
        "            offset = 0,\n",
        "            lyrics = \"\"\"I met a traveller from an antique land,\n",
        "            Who said—“Two vast and trunkless legs of stone\n",
        "            Stand in the desert. . . . Near them, on the sand,\n",
        "            Half sunk a shattered visage lies, whose frown,\n",
        "            And wrinkled lip, and sneer of cold command,\n",
        "            Tell that its sculptor well those passions read\n",
        "            Which yet survive, stamped on these lifeless things,\n",
        "            The hand that mocked them, and the heart that fed;\n",
        "            And on the pedestal, these words appear:\n",
        "            My name is Ozymandias, King of Kings;\n",
        "            Look on my Works, ye Mighty, and despair!\n",
        "            Nothing beside remains. Round the decay\n",
        "            Of that colossal Wreck, boundless and bare\n",
        "            The lone and level sands stretch far away\n",
        "            \"\"\",\n",
        "            ),\n",
        "          ] * hps.n_samples\n",
        "labels = top_prior.labeller.get_batch_labels(metas, 'cuda')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B9onZMEXh34f",
        "colab_type": "text"
      },
      "source": [
        "## Generate 3 options for the start of the song\n",
        "\n",
        "Initial generation is set to be 4 seconds long, but feel free to change this"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c6peEj8I_HHO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def seconds_to_tokens(sec, sr, prior, chunk_size):\n",
        "  tokens = sec * hps.sr // prior.raw_to_tokens\n",
        "  tokens = ((tokens // chunk_size) + 1) * chunk_size\n",
        "  assert tokens <= prior.n_ctx, 'Choose a shorter generation length to stay within the top prior context'\n",
        "  return tokens"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2gn2GXt3zt3y",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "initial_generation_in_seconds = 4\n",
        "tokens_to_sample = seconds_to_tokens(initial_generation_in_seconds, hps.sr, top_prior, chunk_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },
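    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see what `seconds_to_tokens` does with the default 4 seconds, here is the same rounding worked through with plain integers. This is only a sketch under an assumption: 128 raw audio samples per top-level token is not stated in this notebook, so check `top_prior.raw_to_tokens` for the real value. The `example_*` names are illustration-only and do not touch the variables used elsewhere."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Assumed value for illustration; the real one is top_prior.raw_to_tokens\n",
        "example_raw_to_tokens = 128\n",
        "for example_chunk_size in (16, 32):  # chunk sizes used for the 5B and 1B models\n",
        "  example_tokens = 4 * 44100 // example_raw_to_tokens  # 1378 tokens for 4 seconds\n",
        "  # Move up to the next multiple of the chunk size\n",
        "  example_tokens = ((example_tokens // example_chunk_size) + 1) * example_chunk_size\n",
        "  print(example_chunk_size, example_tokens)\n",
        "# 16 1392\n",
        "# 32 1408"
      ],
      "execution_count": 0,
      "outputs": []
    },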
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U0zcWcMoiigl",
        "colab_type": "text"
      },
      "source": [
        "Change the sampling temperature if you like (higher is more random).  Our favorite is in the range .98 to .995"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NHbH68H7VMeO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sampling_temperature = .98\n",
        "sampling_kwargs = dict(temp=sampling_temperature, fp16=True,\n",
        "                       max_batch_size=max_batch_size, chunk_size=chunk_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JGZEPe-WTt4g",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "zs=[t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(3)]\n",
        "zs=sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
        "x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mveN4Be8jK2J",
        "colab_type": "text"
      },
      "source": [
        "Listen to your generated samples, and then pick a favorite. If you don't like any, go back and rerun the cell above. \n",
        "\n",
        "** NOTE this is at the noisy top level, upsample fully (in the next section) to hear the final audio version"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "LrJSGMhUOhZg",
        "colab": {}
      },
      "source": [
        "for i in range(hps.n_samples):\n",
        "  librosa.output.write_wav(f'noisy_top_level_generation_{i}.wav', x[i], sr=44100)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rQ4ersQ5OhZr",
        "colab": {}
      },
      "source": [
        "Audio('noisy_top_level_generation_0.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "-GdqzrGkOhZv",
        "colab": {}
      },
      "source": [
        "Audio('noisy_top_level_generation_1.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "gE5S8hyZOhZy",
        "colab": {}
      },
      "source": [
        "Audio('noisy_top_level_generation_2.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t2-mEJaqZfuS",
        "colab_type": "text"
      },
      "source": [
        "If you don't like any of the options, return a few cells back to \"Sample a few options...\" and rerun from there."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o7CzSiv0MmFP",
        "colab_type": "text"
      },
      "source": [
        "## Choose your favorite sample and request longer generation\n",
        "\n",
        "---\n",
        "\n",
        "(Repeat from here)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "j_XFtVi99CIY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "my_choice=0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Pgk3sHHBLYoq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "zs[2]=zs[2][my_choice].repeat(hps.n_samples,1)\n",
        "t.save(zs, 'zs-checkpoint2.t')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W8Rd9xxm565S",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Set to True to load the previous checkpoint:\n",
        "if False:\n",
        "  zs=t.load('zs-checkpoint2.t') "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k12xjMgHkRGP",
        "colab_type": "text"
      },
      "source": [
        "Choose the length of the continuation.  The 1B model can generate up to 17 second samples and the 5B up to 23 seconds, but you'll want to pick a shorter continuation length so that it will be able to look back at what you've generated already.  Here we've chosen 4 seconds."
      ]
    },
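    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The 17 and 23 second limits quoted above can be checked with a little arithmetic. This is only a sketch: it assumes the top prior's context window equals the `hps.sample_length` values set earlier in this section (786432 samples for 1B, 1048576 for 5B) at 44.1 kHz. The names `sr_hz`, `window_samples`, and `window_seconds` are made up for this illustration."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Illustrative only: window sizes are copied from the hps.sample_length line earlier in this section\n",
        "sr_hz = 44100\n",
        "window_samples = {'1b_lyrics': 786432, '5b_lyrics': 1048576}\n",
        "# Convert each window from raw audio samples to seconds\n",
        "window_seconds = {name: n / sr_hz for name, n in window_samples.items()}\n",
        "for name, seconds in window_seconds.items():\n",
        "  print(f'{name}: {seconds:.1f} seconds per context window')\n",
        "# 1b_lyrics: 17.8 seconds per context window\n",
        "# 5b_lyrics: 23.8 seconds per context window"
      ],
      "execution_count": 0,
      "outputs": []
    },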
    {
      "cell_type": "code",
      "metadata": {
        "id": "h3_-0a07kHHG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "continue_generation_in_seconds=4\n",
        "tokens_to_sample = seconds_to_tokens(continue_generation_in_seconds, hps.sr, top_prior, chunk_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GpPG3Ifqk8ue",
        "colab_type": "text"
      },
      "source": [
        "The next step asks the top prior to generate more of the sample. It'll take up to a few minutes, depending on the sample length you request."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YoHkeSTaEyLj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "zs = sample_partial_window(zs, labels, sampling_kwargs, 2, top_prior, tokens_to_sample, hps)\n",
        "x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ymhUqEdhleEi",
        "colab_type": "text"
      },
      "source": [
        "Now listen to the longer versions of the sample you selected, and again choose a favorite sample.  If you don't like any, return back to the cell where you can load the checkpoint, and continue again from there.\n",
        "\n",
        "When the samples start getting long, you might not always want to listen from the start, so change the playback start time later on if you like."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2H1LNLTa_R6a",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "playback_start_time_in_seconds = 0 "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "r4SBGAmsnJtH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "for i in range(hps.n_samples):\n",
        "  librosa.output.write_wav(f'top_level_continuation_{i}.wav', x[i][playback_start_time_in_seconds*44100:], sr=44100)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2WeyE5Qtnmeo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Audio('top_level_continuation_0.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BKtfEtcaazXE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Audio('top_level_continuation_1.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7yrlS0XwK2S0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Audio('top_level_continuation_2.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-OJT704dvnGv",
        "colab_type": "text"
      },
      "source": [
        "To make a longer song, return back to \"Choose your favorite sample\" and loop through that again"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RzCrkCZJvUcQ",
        "colab_type": "text"
      },
      "source": [
        "# Upsample Co-Composition to Higher Audio Quality"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4MPgukwMmB0p",
        "colab_type": "text"
      },
      "source": [
        "Choose your favorite sample from your latest group of generations.  (If you haven't already gone through the Co-Composition block, make sure to do that first so you have a generation to upsample)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yv-pNNPHBQYC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "choice = 0\n",
        "select_best_sample = True  # Set false if you want to upsample all your samples \n",
        "                           # upsampling sometimes yields subtly different results on multiple runs,\n",
        "                           # so this way you can choose your favorite upsampling"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v17cEAqyCgfo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "if select_best_sample:\n",
        "  zs[2]=zs[2][choice].repeat(zs[2].shape[0],1)\n",
        "\n",
        "t.save(zs, 'zs-top-level-final.t')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0YjK-Ac0tBfu",
        "colab_type": "text"
      },
      "source": [
        "Note: If you are using a CoLab hosted runtime on the free tier, you may want to download this zs-top-level-final.t file, and then restart an instance and load it in the next cell.  The free tier will last a maximum of 12 hours, and the upsampling stage can take many hours, depending on how long a sample you have generated."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qqlR9368s3jJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "if False:\n",
        "  zs = t.load('zs-top-level-final.t')\n",
        "\n",
        "assert zs[2].shape[1]>=2048, f'Please first generate at least 2048 tokens at the top level, currently you have {zs[2].shape[1]}'\n",
        "hps.sample_length = zs[2].shape[1]*top_prior.raw_to_tokens"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jzHwF_iqgIWM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Set this False if you are on a local machine that has enough memory (this allows you to do the\n",
        "# lyrics alignment visualization). For a hosted runtime, we'll need to go ahead and delete the top_prior\n",
        "# if you are using the 5b_lyrics model.\n",
        "if True:\n",
        "  del top_prior\n",
        "  empty_cache()\n",
        "  top_prior=None\n",
        "\n",
        "upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q22Ier6YSkKS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=16, chunk_size=32),\n",
        "                    dict(temp=0.99, fp16=True, max_batch_size=16, chunk_size=32),\n",
        "                    None]\n",
        "\n",
        "if type(labels)==dict:\n",
        "  labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers] + [labels] "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T1MCa9_jnjpf",
        "colab_type": "text"
      },
      "source": [
        "This next step upsamples 2 levels.  The level_1 samples will be available after around one hour (depending on the length of your sample) and are saved under {hps.name}/level_0/item_0.wav, while the fully upsampled level_0 will likely take 4-12 hours. You can access the wav files down below, or using the \"Files\" panel at the left of this CoLab.\n",
        "\n",
        "(Please note, if you are using this CoLab on Google's free tier, you may want to download intermediate steps as the connection will last for a maximum 12 hours.)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NcNT5qIRMmHq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W2jTYLPBc29M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "Audio(f'{hps.name}/level_0/item_0.wav')"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}